# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train a ResNet-50 model on ImageNet."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

from absl import flags
import absl.logging as _logging  # pylint: disable=unused-import
import tensorflow as tf

from . import imagenet_input  # pylint: disable=relative-beyond-top-level
from . import resnet_model  # pylint: disable=relative-beyond-top-level
from . import hypertune_hook  # pylint: disable=relative-beyond-top-level
from tensorflow.contrib import summary
from tensorflow.contrib.training.python.training import evaluation
from tensorflow.python.estimator import estimator

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'gcp_project', default=None,
    help='Project name. If not specified, we '
    'will attempt to automatically detect the GCE project from metadata.')

# Model specific flags
flags.DEFINE_string(
    'data_dir', default=None,
    help=('The directory where the ImageNet input data is stored. Please see'
          ' the README.md for the expected data format.'))

flags.DEFINE_string(
    'model_dir', default=None,
    help=('The directory where the model and training/evaluation summaries are'
          ' stored.'))

flags.DEFINE_integer(
    'resnet_depth', default=50,
    help=('Depth of ResNet model to use. Must be one of {18, 34, 50, 101, 152,'
          ' 200}. ResNet-18 and 34 use the pre-activation residual blocks'
          ' without bottleneck layers. The other models use pre-activation'
          ' bottleneck layers. Deeper models require more training time and'
          ' more memory and may require reducing --train_batch_size to prevent'
          ' running out of memory.'))

flags.DEFINE_string(
    'mode', default='train_and_eval',
    help='One of {"train_and_eval", "train", "eval"}.')

flags.DEFINE_integer(
    'train_steps', default=112603,
    help=('The number of steps to use for training. Default is 112603 steps'
          ' which is approximately 90 epochs at batch size 1024. This flag'
          ' should be adjusted according to the --train_batch_size flag.'))

flags.DEFINE_integer(
    'train_batch_size', default=32, help='Batch size for training.')

flags.DEFINE_integer(
    'eval_batch_size', default=32, help='Batch size for evaluation.')

flags.DEFINE_integer(
    'steps_per_eval', default=5000,
    help=('Controls how often evaluation is performed. Since evaluation is'
          ' fairly expensive, it is advised to evaluate as infrequently as'
          ' possible (i.e. up to --train_steps, which evaluates the model'
          ' only after finishing the entire training regime).'))

flags.DEFINE_integer(
    'eval_timeout',
    default=None,
    help=(
        'Maximum seconds between checkpoints before evaluation terminates.'))

flags.DEFINE_bool(
    'transpose_input', default=False,
    help=('Transpose the input images from NHWC to HWCN in the input pipeline'
          ' as a performance optimization; the model_fn transposes them back'
          ' to NHWC before building the network.'))

flags.DEFINE_bool(
    'skip_host_call', default=False,
    help=('Skip the host_call which is executed every training step. This is'
          ' generally used for generating training summaries (train loss,'
          ' learning rate, etc...). '))

flags.DEFINE_string(
    'data_format', default='channels_last',
    help=('A flag to override the data format used in the model. The value'
          ' is either channels_first or channels_last. To run the network on'
          ' CPU or TPU, channels_last should be used. For GPU, channels_first'
          ' will improve performance.'))

flags.DEFINE_string(
    'export_dir',
    default=None,
    help=('The directory where the exported SavedModel will be stored.'))

flags.DEFINE_string(
    'precision', default='float32',
    help=('Precision to use; one of: {bfloat16, float32}'))

flags.DEFINE_float(
    'base_learning_rate', default=0.1,
    help=('Base learning rate when train batch size is 256.'))

flags.DEFINE_float(
    'momentum', default=0.9,
    help=('Momentum parameter used in the MomentumOptimizer.'))

flags.DEFINE_float(
    'weight_decay', default=1e-4,
    help=('Weight decay coefficient for L2 regularization.'))
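
# Example invocation (illustrative only; the module path and GCS locations are
# placeholders, not names defined by this sample):
#   python -m trainer.resnet_main \
#       --data_dir=gs://<bucket>/imagenet \
#       --model_dir=gs://<bucket>/resnet50-model \
#       --train_batch_size=1024 \
#       --train_steps=112603 \
#       --mode=train_and_eval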

# Dataset constants
LABEL_CLASSES = 1000
NUM_TRAIN_IMAGES = 1281167
NUM_EVAL_IMAGES = 50000

# Learning rate schedule
LR_SCHEDULE = [    # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]
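# The first LR_SCHEDULE entry also defines the warmup: the learning rate ramps
# linearly from 0 to the scaled base rate over the first 5 epochs (see
# learning_rate_schedule below).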

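# Standard per-channel ImageNet mean and standard deviation, expressed for
# pixel values scaled to [0, 1]; resnet_model_fn uses these to normalize the
# input images.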
MEAN_RGB = [0.485, 0.456, 0.406]
STDDEV_RGB = [0.229, 0.224, 0.225]


def learning_rate_schedule(current_epoch):
    """Handles linear scaling rule, gradual warmup, and LR decay.

    The learning rate starts at 0, then it increases linearly per step.
    After 5 epochs we reach the base learning rate (scaled to account
      for batch size).
    After 30, 60 and 80 epochs the learning rate is divided by 10.
    After 90 epochs training stops and the LR is set to 0. This ensures
      that we train for exactly 90 epochs for reproducibility.

    Args:
      current_epoch: `Tensor` for current epoch.

    Returns:
      A scaled `Tensor` for current learning rate.
    """
    scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)

    decay_rate = (scaled_lr * LR_SCHEDULE[0][0] *
                  current_epoch / LR_SCHEDULE[0][1])
    for mult, start_epoch in LR_SCHEDULE:
        decay_rate = tf.where(current_epoch < start_epoch,
                              decay_rate, scaled_lr * mult)
    return decay_rate


def resnet_model_fn(features, labels, mode, params):
    """The model_fn for ResNet to be used with Estimator.

    Args:
      features: `Tensor` of batched images.
      labels: `Tensor` of labels for the data samples.
      mode: one of `tf.estimator.ModeKeys.{TRAIN, EVAL, PREDICT}`.
      params: `dict` of parameters passed to the model from the Estimator,
          `params['batch_size']` is always provided and should be used as the
          effective batch size.

    Returns:
      An `EstimatorSpec` for the model.
    """
    if isinstance(features, dict):
        features = features['feature']

    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input    # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])

    if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance.
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

    # This nested function allows us to avoid duplicating the logic which
    # builds the network, for different values of --precision.
    def build_network():
        network = resnet_model.resnet_v1(
            resnet_depth=FLAGS.resnet_depth,
            num_classes=LABEL_CLASSES,
            data_format=FLAGS.data_format)
        return network(
            inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    logits = build_network()

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the
    # batch size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']   # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy
    # and L2 regularization.
    one_hot_labels = tf.one_hot(labels, LABEL_CLASSES)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=one_hot_labels)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + FLAGS.weight_decay * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()
         if 'batch_normalization' not in v.name])

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Compute the current epoch and associated learning rate
        # from global_step.
        global_step = tf.train.get_global_step()
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        current_epoch = (tf.cast(global_step, tf.float32) /
                         batches_per_epoch)
        learning_rate = learning_rate_schedule(current_epoch)

        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=FLAGS.momentum,
            use_nesterov=True)
        # Batch normalization requires UPDATE_OPS to be added as a
        # dependency to the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if not FLAGS.skip_host_call:
            def host_call_fn(gs, loss, lr, ce):
                """Training host call.
                Creates scalar summaries for training metrics.

                This function is executed on the CPU and should not
                directly reference any Tensors in the rest of the
                `model_fn`. To pass Tensors from the model to this
                function, provide them as part of the `host_call`.

                Arguments should match the list of `Tensor` objects
                passed as the second element in the tuple passed to
                `host_call`.

                Args:
                  gs: `Tensor` with shape `[batch]` for the global_step.
                  loss: `Tensor` with shape `[batch]` for the training loss.
                  lr: `Tensor` with shape `[batch]` for the learning_rate.
                  ce: `Tensor` with shape `[batch]` for the current_epoch.

                Returns:
                  List of summary ops to run on the CPU host.
                """
                gs = gs[0]
                with summary.create_file_writer(FLAGS.model_dir).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('loss', loss[0], step=gs)
                        summary.scalar('learning_rate', lr[0], step=gs)
                        summary.scalar('current_epoch', ce[0], step=gs)

                        return summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for TensorBoard,
            # the summary op needs to be run on the host CPU via host_call.
            # host_call expects [batch_size, ...] Tensors, thus reshape to
            # introduce a batch dimension. These Tensors are implicitly
            # concatenated to [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        predictions = tf.argmax(logits, axis=1)
        top_1_accuracy = tf.metrics.accuracy(labels, predictions)
        in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
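        # tf.nn.in_top_k yields a boolean per example indicating whether the
        # true label is among the 5 largest logits; its mean is top-5 accuracy.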
        # Give the metric a name to make it easier to retrieve in
        # SessionRunHook.
        top_5_accuracy = tf.metrics.mean(in_top_5, name='top_5_accuracy')
        eval_metrics = {
            'top_1_accuracy': top_1_accuracy,
            'top_5_accuracy': top_5_accuracy,
        }

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metrics)


def main(unused_argv):
    trial_id = os.environ.get('CLOUD_ML_TRIAL_ID', 0)
    if trial_id:
        if FLAGS.export_dir:
            FLAGS.export_dir = os.path.join(
                FLAGS.export_dir, 'trial_{}'.format(trial_id))
        FLAGS.model_dir = os.path.join(
            FLAGS.model_dir, 'trial_{}'.format(trial_id))

    resnet_classifier = tf.estimator.Estimator(
        model_fn=resnet_model_fn,
        model_dir=FLAGS.model_dir,
        params={'batch_size': FLAGS.train_batch_size})

    assert FLAGS.precision == 'bfloat16' or FLAGS.precision == 'float32', (
        'Invalid value for --precision flag; must be bfloat16 or float32.')
    tf.logging.info('Precision: %s', FLAGS.precision)
    use_bfloat16 = FLAGS.precision == 'bfloat16'

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train, imagenet_eval = [imagenet_input.ImageNetInput(
        is_training=is_training, data_dir=FLAGS.data_dir,
        transpose_input=FLAGS.transpose_input,
        use_bfloat16=use_bfloat16) for is_training in [True, False]]

    if FLAGS.mode == 'eval':
        eval_steps = NUM_EVAL_IMAGES // FLAGS.eval_batch_size

        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                # This time will include compilation time
                start_timestamp = time.time()
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d' %
                                (eval_results, elapsed_time))

                # Terminate eval job when final checkpoint is reached
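                # Checkpoint paths look like '<model_dir>/model.ckpt-<step>',
                # so the global step is the suffix after the '-'.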
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d'
                        % current_step)
                    break

            except tf.errors.NotFoundError:
                tf.logging.info(
                    'Checkpoint %s no longer exists, '
                    'skipping checkpoint' % ckpt)

    else:   # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                        ' step %d.' % (FLAGS.train_steps,
                                       FLAGS.train_steps / batches_per_epoch,
                                       current_step))

        start_timestamp = time.time()  # This time will include compilation time
        if FLAGS.mode == 'train':
            tf.logging.info('Training for trial_{}'.format(trial_id))
            resnet_classifier.train(
                input_fn=imagenet_train.input_fn, max_steps=FLAGS.train_steps)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written
                # to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                tf.logging.info('Training for trial_{}'.format(trial_id))
                resnet_classifier.train(
                    input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
                current_step = next_checkpoint

                # Evaluate the model on the most recent checkpoint in
                # --model_dir. Since evaluation happens in batches of
                # --eval_batch_size, images that do not fill a final
                # batch may be consistently excluded.
                tf.logging.info('Starting to evaluate.')
                tf.logging.info('Evaluating for trial_{}'.format(trial_id))

                evaluation_hooks = [hypertune_hook.HypertuneHook(
                    'top_5_accuracy')]
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=NUM_EVAL_IMAGES // FLAGS.eval_batch_size,
                    hooks=evaluation_hooks)
                tf.logging.info('Eval results: %s' % eval_results)

        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info('Finished training up to step %d. Elapsed seconds %d.' %
                        (FLAGS.train_steps, elapsed_time))

        if FLAGS.export_dir is not None:
            # The guide to serving an exported TensorFlow model is at:
            #    https://www.tensorflow.org/serving/serving_basic
            tf.logging.info('Starting to export model.')
            resnet_classifier.export_savedmodel(
                export_dir_base=FLAGS.export_dir,
                serving_input_receiver_fn=imagenet_input.image_serving_input_fn)

if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()
